/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.bigdl.dllib.nn

import com.intel.analytics.bigdl.dllib.nn.mkldnn.Equivalent
import org.scalatest.{FlatSpec, Matchers}
import com.intel.analytics.bigdl.dllib.tensor.Tensor
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath.TensorNumeric.NumericFloat
import com.intel.analytics.bigdl.dllib.utils.T
import com.intel.analytics.bigdl.dllib.utils.serializer.ModuleSerializationTest

import scala.util.Random

class FPNSpec extends FlatSpec with Matchers {
  "FPN updateOutput with None TopBlocks" should "work correctly" in {
    // Three input levels carrying 1, 2 and 4 channels, each projected to 2 channels.
    // topBlocks = 0 requests no extra level, so exactly three outputs are compared below.
    val inChannelsList = Array(1, 2, 4)
    val outChannels = 2
    val model = FPN[Float](inChannelsList, outChannels, topBlocks = 0)

    // Every bias in this test is zero; a fresh tensor is built for each parameter slot.
    def zeroBias(): Tensor[Float] = Tensor(T(0, 0))

    // Input level 1, reshaped to (1, 1, 8, 8).
    val level1 = Tensor(
      T(T(0.10110152, 0.10345000, 0.04320979, 0.84362656,
          0.59594363, 0.97288179, 0.34699517, 0.54275155),
        T(0.93956870, 0.07543808, 0.50965708, 0.26184946,
          0.92378283, 0.83272308, 0.54440099, 0.56682664),
        T(0.53608388, 0.74091697, 0.53824615, 0.12760854,
          0.70029002, 0.85137993, 0.01918983, 0.10134047),
        T(0.61024511, 0.11725241, 0.46950370, 0.15163177,
          0.99792290, 0.50036842, 0.65618765, 0.76569498),
        T(0.31238246, 0.96460360, 0.23587847, 0.94086981,
          0.15270233, 0.44916826, 0.53412461, 0.19992995),
        T(0.14841199, 0.95466810, 0.89249784, 0.10235202,
          0.24293590, 0.83814293, 0.78163254, 0.94990700),
        T(0.50397956, 0.23095572, 0.12026519, 0.70295823,
          0.80230796, 0.31913465, 0.86270124, 0.67926580),
        T(0.93120003, 0.08011329, 0.30662805, 0.97467756,
          0.32988423, 0.90689850, 0.46856666, 0.66390038)))
      .reshape(Array(1, 1, 8, 8))

    // Input level 2, reshaped to (1, 2, 4, 4).
    val level2 = Tensor(
      T(T(T(0.30143285, 0.63111430, 0.45092928, 0.22753167),
          T(0.80318344, 0.67537767, 0.14698678, 0.45962620),
          T(0.21663177, 0.89086282, 0.92865956, 0.89360029),
          T(0.49615270, 0.46269470, 0.73047608, 0.12438315)),
        T(T(0.75820625, 0.59779423, 0.61585987, 0.35782731),
          T(0.36951083, 0.35381025, 0.64314663, 0.75517660),
          T(0.30200917, 0.69998586, 0.29572868, 0.46342885),
          T(0.41677684, 0.26154006, 0.16909349, 0.94081402))))
      .reshape(Array(1, 2, 4, 4))

    // Input level 3, reshaped to (1, 4, 2, 2).
    val level3 = Tensor(
      T(T(T(0.57270211, 0.25789189),
          T(0.79134840, 0.62564188)),
        T(T(0.27365083, 0.43420678),
          T(0.61281836, 0.23570287)),
        T(T(0.21393263, 0.50206852),
          T(0.50650394, 0.73282623)),
        T(T(0.20319027, 0.06753725),
          T(0.18215942, 0.36703324))))
      .reshape(Array(1, 4, 2, 2))

    // Inner weights with 1x1 kernels; the leading size-1 axis is presumably a group axis.
    val inner1W = Tensor(
      T(T(T(T(0.24687862))),
        T(T(T(-0.56227243)))))
      .reshape(Array(1, 2, 1, 1, 1))

    val inner2W = Tensor(
      T(T(T(T(0.04691243)),
          T(T(-0.90420955))),
        T(T(T(1.09895408)),
          T(T(0.51624501)))))
      .reshape(Array(1, 2, 2, 1, 1))

    val inner3W = Tensor(
      T(T(T(T(0.25616819)),
          T(T(-0.74193102)),
          T(T(0.22137421)),
          T(T(0.53996474))),
        T(T(T(-0.30102068)),
          T(T(0.24491900)),
          T(T(-0.84143710)),
          T(T(-0.73395455)))))
      .reshape(Array(1, 2, 4, 1, 1))

    // Layer weights with 3x3 kernels, shaped (1, 2, 2, 3, 3).
    val layer1W = Tensor(
      T(T(T(T(-0.04048228, 0.16222215, 0.10794550),
            T(-0.34169874, -0.25080314, 0.11539066),
            T(-0.27039635, 0.19380659, 0.19993830)),
          T(T(0.12585402, -0.38708800, 0.09077036),
            T(0.12301302, -0.29949811, 0.12835038),
            T(-0.32869643, 0.37100095, -0.26665413))),
        T(T(T(-0.23543328, -0.24697217, 0.15786803),
            T(0.19520867, -0.06484443, 0.39382762),
            T(-0.09158209, -0.22267270, 0.23828101)),
          T(T(0.16857922, -0.26403868, -0.07582438),
            T(0.31187642, -0.14743957, 0.19229126),
            T(-0.00750843, -0.21541777, -0.04269919)))))
      .reshape(Array(1, 2, 2, 3, 3))

    val layer2W = Tensor(
      T(T(T(T(-0.14214972, -0.17213514, -0.32127398),
            T(-0.23303765, -0.27284676, -0.05630624),
            T(-0.03209409, -0.16349350, -0.13884634)),
          T(T(0.05150193, -0.01451367, 0.29302871),
            T(0.38110715, 0.21102744, -0.01252702),
            T(-0.14486188, 0.39937240, 0.26671016))),
        T(T(T(-0.20462120, -0.03479487, -0.01640993),
            T(0.34504193, 0.11599201, 0.40438360),
            T(-0.17013551, 0.00606328, -0.14445123)),
          T(T(0.15805143, -0.06925225, -0.24366492),
            T(-0.16341771, -0.31556514, 0.03696010),
            T(0.07415351, -0.08760622, -0.17086124)))))
      .reshape(Array(1, 2, 2, 3, 3))

    val layer3W = Tensor(
      T(T(T(T(-0.21088375, 0.39961314, 0.28634924),
            T(-0.09605905, -0.09238201, 0.29169798),
            T(-0.16913360, 0.34432471, 0.12923980)),
          T(T(0.15992212, 0.11829317, -0.08958191),
            T(0.29556727, 0.28719366, 0.35837567),
            T(0.35775679, 0.13369364, 0.22401685))),
        T(T(T(0.23750001, -0.26816195, -0.33834153),
            T(0.02364820, -0.28069261, -0.31661153),
            T(-0.05442283, 0.30038035, 0.23050475)),
          T(T(0.24013102, -0.04941136, -0.01676598),
            T(0.36672127, -0.14019510, -0.18527937),
            T(-0.21643242, -0.06160817, 0.14386815)))))
      .reshape(Array(1, 2, 2, 3, 3))

    // Reference outputs, one per input level, each with 2 channels.
    val expected1 = Tensor(
      T(T(T(T(-0.29643691, 0.32930288, 0.07719041, 0.20329267, -0.11702696,
              0.33030477, 0.19752777, 0.26074126),
            T(-0.04022884, -0.04050549, -0.17072679, 0.05824373, -0.18035993,
              -0.10781585, 0.21838233, 0.35475171),
            T(-0.14252800, -0.16825707, -0.28704056, -0.26278189, -0.19001812,
              0.20092483, 0.17245048, 0.46969670),
            T(-0.14943303, -0.45888224, 0.33286753, -0.42771903, 0.47255370,
              0.24915743, -0.21637592, 0.21200535),
            T(0.00808068, -0.16809230, -0.14534889, 0.29852685, 0.36068499,
              -0.19606119, -0.18463834, -0.19501874),
            T(-0.06999602, 0.55371714, -0.33532500, 0.29894528, 0.44789663,
              0.21802102, -0.32107252, -0.07110818),
            T(-0.19171244, 0.50532514, 0.00852559, -0.05432931, 0.56445789,
              -0.21175916, 0.01788443, 0.39967728),
            T(0.11412182, -0.05338766, 0.11950107, 0.33978215, 0.17466278,
              -0.22752701, 0.06036017, 0.51162905)),
          T(T(-0.18407047, -0.06274336, -0.19927005, -0.18067920, -0.12339569,
              -0.10210013, -0.13622473, 0.09764731),
            T(-0.21372095, -0.12506956, -0.10981269, -0.22901297, 0.15182146,
              0.01927174, -0.11695608, 0.25842062),
            T(-0.08454411, 0.00893094, 0.06784435, -0.36769092, 0.24231599,
              -0.07395025, -0.20645590, 0.32848105),
            T(0.07287200, 0.06812082, 0.00125982, -0.20824122, 0.26192454,
              -0.27801457, -0.43661070, 0.24346380),
            T(-0.08816936, -0.14699535, -0.50232911, 0.17301719, 0.39865568,
              0.21348065, 0.22505483, 0.28257197),
            T(0.12479763, -0.03339935, -0.48426947, 0.55722409, 0.36770806,
              -0.01681852, 0.11375013, 0.19888467),
            T(0.14368367, 0.01942967, -0.23314725, 0.41997516, 0.39273715,
              -0.40041974, -0.07516777, 0.04501504),
            T(-0.00356270, -0.15851222, 0.04203597, 0.33169088, -0.02303683,
              -0.42069232, -0.08245742, 0.06082898)))))

    val expected2 = Tensor(
      T(T(T(T(0.67646873, 0.75461042, 0.88370752, 0.72522950),
            T(0.80561060, 1.40666068, 0.81269693, 0.72721291),
            T(0.42856935, 0.57526082, 0.84400183, 0.24381584),
            T(0.60819602, 0.32838598, 0.17468216, -0.05505963)),
          T(T(-0.41587284, -0.59085888, -0.50279200, -0.25322908),
            T(-0.42020139, -0.64106256, -0.23952308, -0.29740968),
            T(-0.31366453, -0.12451494, -0.13788190, 0.07498236),
            T(-0.31522152, -0.13974780, -0.06333419, 0.15230046)))))

    val expected3 = Tensor(
      T(T(T(T(-0.60857159, -0.49706429),
            T(-0.44821957, -0.69798434)),
          T(T(0.11003723, 0.24464746),
            T(0.21994369, -0.22257896)))))

    // Parameter slots are filled in this order: inner weights from level 3 down to
    // level 1, then layer weights in the same order, each weight followed by its bias.
    val slots = model.parameters()._1
    Seq(inner3W, zeroBias(), inner2W, zeroBias(), inner1W, zeroBias(),
      layer3W, zeroBias(), layer2W, zeroBias(), layer1W, zeroBias())
      .zipWithIndex
      .foreach { case (src, idx) => slots(idx).copy(src) }

    val actual = model.forward(T(level1, level2, level3)).toTable

    // Output tables are 1-indexed, hence `idx + 1`.
    Seq(expected1, expected2, expected3).zipWithIndex.foreach { case (want, idx) =>
      Equivalent.nearequals(actual.get[Tensor[Float]](idx + 1).get, want) should be(true)
    }
  }

  "FPN updateOutput with MaxPooling TopBlocks" should "work correctly" in {
    // Three input levels carrying 1, 2 and 4 channels, each projected to 2 channels.
    // topBlocks = 1 adds one extra level on top, so four outputs are compared below.
    val inChannelsList = Array(1, 2, 4)
    val outChannels = 2
    val model = FPN[Float](inChannelsList, outChannels, topBlocks = 1)

    // Every bias in this test is zero; a fresh tensor is built for each parameter slot.
    def zeroBias(): Tensor[Float] = Tensor(T(0, 0))

    // Input level 1, reshaped to (1, 1, 8, 8).
    val level1 = Tensor(
      T(T(0.10110152, 0.10345000, 0.04320979, 0.84362656,
          0.59594363, 0.97288179, 0.34699517, 0.54275155),
        T(0.93956870, 0.07543808, 0.50965708, 0.26184946,
          0.92378283, 0.83272308, 0.54440099, 0.56682664),
        T(0.53608388, 0.74091697, 0.53824615, 0.12760854,
          0.70029002, 0.85137993, 0.01918983, 0.10134047),
        T(0.61024511, 0.11725241, 0.46950370, 0.15163177,
          0.99792290, 0.50036842, 0.65618765, 0.76569498),
        T(0.31238246, 0.96460360, 0.23587847, 0.94086981,
          0.15270233, 0.44916826, 0.53412461, 0.19992995),
        T(0.14841199, 0.95466810, 0.89249784, 0.10235202,
          0.24293590, 0.83814293, 0.78163254, 0.94990700),
        T(0.50397956, 0.23095572, 0.12026519, 0.70295823,
          0.80230796, 0.31913465, 0.86270124, 0.67926580),
        T(0.93120003, 0.08011329, 0.30662805, 0.97467756,
          0.32988423, 0.90689850, 0.46856666, 0.66390038)))
      .reshape(Array(1, 1, 8, 8))

    // Input level 2, reshaped to (1, 2, 4, 4).
    val level2 = Tensor(
      T(T(T(0.30143285, 0.63111430, 0.45092928, 0.22753167),
          T(0.80318344, 0.67537767, 0.14698678, 0.45962620),
          T(0.21663177, 0.89086282, 0.92865956, 0.89360029),
          T(0.49615270, 0.46269470, 0.73047608, 0.12438315)),
        T(T(0.75820625, 0.59779423, 0.61585987, 0.35782731),
          T(0.36951083, 0.35381025, 0.64314663, 0.75517660),
          T(0.30200917, 0.69998586, 0.29572868, 0.46342885),
          T(0.41677684, 0.26154006, 0.16909349, 0.94081402))))
      .reshape(Array(1, 2, 4, 4))

    // Input level 3, reshaped to (1, 4, 2, 2).
    val level3 = Tensor(
      T(T(T(0.57270211, 0.25789189),
          T(0.79134840, 0.62564188)),
        T(T(0.27365083, 0.43420678),
          T(0.61281836, 0.23570287)),
        T(T(0.21393263, 0.50206852),
          T(0.50650394, 0.73282623)),
        T(T(0.20319027, 0.06753725),
          T(0.18215942, 0.36703324))))
      .reshape(Array(1, 4, 2, 2))

    // Inner weights with 1x1 kernels; the leading size-1 axis is presumably a group axis.
    val inner1W = Tensor(
      T(T(T(T(1.00586259))),
        T(T(T(0.53887093)))))
      .reshape(Array(1, 2, 1, 1, 1))

    val inner2W = Tensor(
      T(T(T(T(-0.57429278)),
          T(T(-0.24179715))),
        T(T(T(0.67793036)),
          T(T(-0.94123614)))))
      .reshape(Array(1, 2, 2, 1, 1))

    val inner3W = Tensor(
      T(T(T(T(0.17291552)),
          T(T(-0.05612940)),
          T(T(0.36356455)),
          T(T(-0.79740608))),
        T(T(T(0.72361153)),
          T(T(-0.31787324)),
          T(T(0.04836881)),
          T(T(0.45409185)))))
      .reshape(Array(1, 2, 4, 1, 1))

    // Layer weights with 3x3 kernels, shaped (1, 2, 2, 3, 3).
    val layer1W = Tensor(
      T(T(T(T(0.06878856, -0.35743117, 0.31631619),
            T(-0.14119744, 0.30255783, 0.14926106),
            T(-0.38726792, 0.04510748, -0.36082375)),
          T(T(-0.23815951, -0.38959473, 0.05021074),
            T(0.19526446, -0.35286927, -0.39654526),
            T(-0.00148910, -0.24063437, -0.29699990))),
        T(T(T(0.07476860, -0.02564883, 0.09487671),
            T(-0.01090044, 0.23407942, 0.24647915),
            T(-0.38014463, 0.33695221, 0.40465516)),
          T(T(-0.00955230, 0.37457061, 0.10492092),
            T(-0.12585542, 0.21253753, 0.10564721),
            T(0.07659015, -0.03546333, -0.07322484)))))
      .reshape(Array(1, 2, 2, 3, 3))

    val layer2W = Tensor(
      T(T(T(T(0.32807761, -0.33899420, 0.06800264),
            T(0.07076809, 0.14122516, 0.10424459),
            T(-0.03563347, -0.04193285, 0.26541936)),
          T(T(-0.33386642, 0.38784570, -0.05316493),
            T(-0.37846458, 0.03199247, -0.04221478),
            T(-0.38094023, 0.21109033, 0.18027461))),
        T(T(T(0.08262184, 0.38594717, 0.33632153),
            T(-0.24012834, -0.19122560, 0.35697746),
            T(-0.18635783, -0.16684684, -0.17575860)),
          T(T(-0.24746780, -0.08889309, 0.01367763),
            T(-0.12756592, -0.38951454, -0.28759271),
            T(0.29410106, 0.03703991, 0.06836116)))))
      .reshape(Array(1, 2, 2, 3, 3))

    val layer3W = Tensor(
      T(T(T(T(-0.37369999, -0.19362454, -0.32376695),
            T(0.27765042, -0.03229478, -0.27471265),
            T(0.11516148, 0.22647744, -0.15064511)),
          T(T(0.23695397, 0.32747757, -0.08015823),
            T(0.20880389, 0.34441620, 0.06963590),
            T(-0.18623261, -0.23078077, -0.24822637))),
        T(T(T(-0.00328833, -0.06870756, 0.37950665),
            T(0.39529461, -0.23882844, 0.33771485),
            T(-0.37432045, -0.18209046, 0.07186159)),
          T(T(0.23758322, 0.39008999, -0.22646688),
            T(-0.02726471, 0.03744176, 0.02614474),
            T(0.23741430, -0.14411601, 0.32169968)))))
      .reshape(Array(1, 2, 2, 3, 3))

    // Reference outputs for the three input levels.
    val expected1 = Tensor(
      T(T(T(T(-0.06027000, -0.63544929, -0.53757948, -0.50052786, -0.13087437,
              -0.46421096, -0.50183028, -0.16715574),
            T(-0.57532835, -0.84933513, -0.53883237, -0.96058822, -0.03360630,
              -0.31204289, -0.08640554, -0.05677760),
            T(-1.44774520, -0.81390339, -0.79381657, -0.59526253, 0.18857586,
              -0.44416034, -0.48765731, -0.42349899),
            T(-1.59604609, -1.46630657, -1.98059797, -0.45855987, -0.36292380,
              -0.69029158, -0.14087702, 0.05812502),
            T(-1.48957527, -1.13170886, -1.67776370, -1.03186679, -1.87643552,
              -1.18758655, -1.37157834, -0.78431809),
            T(-0.60699046, -1.50136328, -1.11882496, -2.10094833, -2.03609610,
              -1.95528400, -1.52000046, -0.84520984),
            T(-0.53370321, -1.59671521, -1.87052989, -1.35319352, -2.15164757,
              -1.38132536, -0.71602869, -0.96002018),
            T(-0.67362547, -0.74905103, -0.54873347, -0.96106482, -1.61241412,
              -0.56977570, -0.16225342, -0.17618783)),
          T(T(-0.05780253, -0.46292540, 0.08270258, 0.49520209, 0.78262019,
              0.43522203, 0.31832328, 0.13581192),
            T(0.09274495, -0.03541034, 0.15171435, 0.59660333, 1.20995283,
              0.39199278, -0.01640499, 0.24762793),
            T(0.22356896, 0.16624556, 0.27169892, 0.64676547, 1.07555902,
              0.24264205, 0.10081939, 0.06938417),
            T(1.02827179, 0.52947670, 0.12263671, 0.52516317, -0.09253020,
              0.19734971, 0.02325941, -0.04252122),
            T(1.19363809, 1.01430440, 0.02628537, 0.15349422, 0.21920983,
              0.51211888, 0.26795861, 0.11892501),
            T(0.80922467, 0.72818476, 0.66089183, 0.58932924, 0.70654577,
              1.04901230, 1.27366579, 0.45431411),
            T(0.69929636, 0.23726486, 1.02899837, 0.91582721, 0.73064196,
              1.17346644, 0.96129149, 0.62078124),
            T(0.68208826, 0.32198429, 0.80620545, 0.92930055, 0.98262572,
              0.85965669, 0.12026373, 0.21281539)))))

    val expected2 = Tensor(
      T(T(T(T(0.01130757, -0.25607330, -0.60133040, 0.14051536),
            T(-0.08047363, -0.12201009, -0.46985039, -0.09533735),
            T(0.44763428, -0.26126349, -0.90198398, -0.72634822),
            T(0.03062394, -0.13872625, -0.11850480, -0.42137370)),
          T(T(0.06946735, 0.40965801, 0.48238945, 0.06515967),
            T(-0.61845332, 0.06515414, 0.47206032, 0.74106240),
            T(-0.73383999, -0.66845274, -0.55763161, 0.20043206),
            T(-0.57443708, -1.08336210, -0.94348121, -0.44871095)))))

    val expected3 = Tensor(
      T(T(T(T(-0.11979339, -0.07183948),
            T(0.26843292, 0.44522521)),
          T(T(0.16508943, -0.07747011),
            T(0.22362739, 0.18027946)))))

    // Extra top level: its values equal the top-left entry of each channel of expected3.
    val expected4 = Tensor(
      T(T(T(T(-0.11979339)),
          T(T(0.16508943)))))

    // Parameter slots are filled in this order: inner weights from level 3 down to
    // level 1, then layer weights in the same order, each weight followed by its bias.
    val slots = model.parameters()._1
    Seq(inner3W, zeroBias(), inner2W, zeroBias(), inner1W, zeroBias(),
      layer3W, zeroBias(), layer2W, zeroBias(), layer1W, zeroBias())
      .zipWithIndex
      .foreach { case (src, idx) => slots(idx).copy(src) }

    val actual = model.forward(T(level1, level2, level3)).toTable

    // Output tables are 1-indexed, hence `idx + 1`.
    Seq(expected1, expected2, expected3, expected4).zipWithIndex.foreach { case (want, idx) =>
      Equivalent.nearequals(actual.get[Tensor[Float]](idx + 1).get, want) should be(true)
    }
  }

  "FPN updateOutput with P6P7 TopBlocks not use P5" should "work correctly" in {
    // Three input levels carrying 1, 2 and 4 channels, each projected to 2 channels.
    // topBlocks = 2 adds two extra levels (P6, P7), so five outputs are compared below.
    // inChannelsP6P7 != outChannelsP6P7 here (4 vs 2).
    val inChannelsList = Array(1, 2, 4)
    val outChannels = 2
    val model = FPN[Float](inChannelsList, outChannels, topBlocks = 2,
      inChannelsOfP6P7 = 4, outChannelsOfP6P7 = 2)

    // Every bias in this test is zero; a fresh tensor is built for each parameter slot.
    def zeroBias(): Tensor[Float] = Tensor(T(0, 0))

    // Input level 1, reshaped to (1, 1, 8, 8).
    val level1 = Tensor(
      T(T(0.10110152, 0.10345000, 0.04320979, 0.84362656,
          0.59594363, 0.97288179, 0.34699517, 0.54275155),
        T(0.93956870, 0.07543808, 0.50965708, 0.26184946,
          0.92378283, 0.83272308, 0.54440099, 0.56682664),
        T(0.53608388, 0.74091697, 0.53824615, 0.12760854,
          0.70029002, 0.85137993, 0.01918983, 0.10134047),
        T(0.61024511, 0.11725241, 0.46950370, 0.15163177,
          0.99792290, 0.50036842, 0.65618765, 0.76569498),
        T(0.31238246, 0.96460360, 0.23587847, 0.94086981,
          0.15270233, 0.44916826, 0.53412461, 0.19992995),
        T(0.14841199, 0.95466810, 0.89249784, 0.10235202,
          0.24293590, 0.83814293, 0.78163254, 0.94990700),
        T(0.50397956, 0.23095572, 0.12026519, 0.70295823,
          0.80230796, 0.31913465, 0.86270124, 0.67926580),
        T(0.93120003, 0.08011329, 0.30662805, 0.97467756,
          0.32988423, 0.90689850, 0.46856666, 0.66390038)))
      .reshape(Array(1, 1, 8, 8))

    // Input level 2, reshaped to (1, 2, 4, 4).
    val level2 = Tensor(
      T(T(T(0.30143285, 0.63111430, 0.45092928, 0.22753167),
          T(0.80318344, 0.67537767, 0.14698678, 0.45962620),
          T(0.21663177, 0.89086282, 0.92865956, 0.89360029),
          T(0.49615270, 0.46269470, 0.73047608, 0.12438315)),
        T(T(0.75820625, 0.59779423, 0.61585987, 0.35782731),
          T(0.36951083, 0.35381025, 0.64314663, 0.75517660),
          T(0.30200917, 0.69998586, 0.29572868, 0.46342885),
          T(0.41677684, 0.26154006, 0.16909349, 0.94081402))))
      .reshape(Array(1, 2, 4, 4))

    // Input level 3, reshaped to (1, 4, 2, 2).
    val level3 = Tensor(
      T(T(T(0.57270211, 0.25789189),
          T(0.79134840, 0.62564188)),
        T(T(0.27365083, 0.43420678),
          T(0.61281836, 0.23570287)),
        T(T(0.21393263, 0.50206852),
          T(0.50650394, 0.73282623)),
        T(T(0.20319027, 0.06753725),
          T(0.18215942, 0.36703324))))
      .reshape(Array(1, 4, 2, 2))

    // Inner weights with 1x1 kernels; the leading size-1 axis is presumably a group axis.
    val inner1W = Tensor(
      T(T(T(T(1.47257316))),
        T(T(T(0.57414114)))))
      .reshape(Array(1, 2, 1, 1, 1))

    val inner2W = Tensor(
      T(T(T(T(0.45074105)),
          T(T(-0.30885106))),
        T(T(T(-0.08952701)),
          T(T(-0.26140732)))))
      .reshape(Array(1, 2, 2, 1, 1))

    val inner3W = Tensor(
      T(T(T(T(-0.30031908)),
          T(T(-0.58480197)),
          T(T(0.59235269)),
          T(T(-0.13991892))),
        T(T(T(0.62555033)),
          T(T(0.72914702)),
          T(T(-0.44170576)),
          T(T(0.49929196)))))
      .reshape(Array(1, 2, 4, 1, 1))

    // Layer weights with 3x3 kernels, shaped (1, 2, 2, 3, 3).
    val layer1W = Tensor(
      T(T(T(T(-0.04400888, -0.35957703, -0.02164334),
            T(0.40402526, 0.36285782, 0.31368673),
            T(-0.35616416, -0.21952458, 0.37052453)),
          T(T(0.13778913, -0.30064595, -0.36663383),
            T(0.37170672, 0.32204062, -0.07368714),
            T(0.19972658, -0.39074513, -0.38521481))),
        T(T(T(0.05121413, 0.23705125, 0.13029754),
            T(-0.29272887, 0.08022153, -0.16771419),
            T(-0.38660547, -0.30105561, -0.17050056)),
          T(T(-0.38432136, 0.04626641, 0.20397991),
            T(-0.24799925, -0.34601510, 0.23324311),
            T(0.39426655, -0.28500557, 0.33542544)))))
      .reshape(Array(1, 2, 2, 3, 3))

    val layer2W = Tensor(
      T(T(T(T(-0.26440758, -0.40462878, 0.35458815),
            T(-0.27700549, -0.24707370, 0.14012802),
            T(-0.02187592, 0.12944663, 0.15989727)),
          T(T(0.25460601, 0.33005655, 0.19840294),
            T(0.08936363, -0.01533994, -0.10784483),
            T(0.14462578, -0.32323214, -0.31677228))),
        T(T(T(0.24838877, -0.30633825, -0.14952859),
            T(-0.10827839, -0.09704661, 0.01009622),
            T(-0.17448114, 0.40084583, 0.25651050)),
          T(T(0.02460378, 0.31060696, 0.29154462),
            T(0.04250652, 0.06705299, 0.10902947),
            T(-0.21223937, 0.02931285, -0.20978554)))))
      .reshape(Array(1, 2, 2, 3, 3))

    val layer3W = Tensor(
      T(T(T(T(0.28868508, 0.34335995, -0.21298549),
            T(0.13598031, 0.14855188, 0.16282564),
            T(0.24104220, 0.19631046, 0.28864717)),
          T(T(0.17355555, 0.17067927, 0.34322286),
            T(-0.32470348, -0.15039983, -0.37904710),
            T(-0.32140541, -0.31889421, -0.34283394))),
        T(T(T(-0.27881464, 0.32479310, 0.33741760),
            T(-0.04920617, 0.38263774, 0.37934089),
            T(-0.07421857, 0.28872919, -0.24625073)),
          T(T(-0.07631743, 0.15071201, 0.20164257),
            T(-0.02279785, 0.24347979, -0.33499616),
            T(0.25867003, 0.11343688, -0.39765364)))))
      .reshape(Array(1, 2, 2, 3, 3))

    // P6 weights: nested as (2, 4, 3, 3); copied by element count, so no reshape is used.
    val p6W = Tensor(
      T(T(T(T(0.18047711, 0.27739489, -0.09207258),
            T(-0.06993897, 0.10390741, 0.16100934),
            T(0.09704280, -0.11835672, 0.11216679)),
          T(T(0.07162413, -0.03068006, 0.21995866),
            T(-0.01902044, -0.23496029, 0.08649853),
            T(0.26076275, 0.12215102, -0.24565969)),
          T(T(-0.26359978, 0.28385252, -0.08561571),
            T(0.08719349, -0.03602475, -0.14762157),
            T(0.04393122, 0.15552819, -0.19104180)),
          T(T(-0.26122716, -0.23169036, 0.04371125),
            T(0.17757964, 0.18492216, 0.18820083),
            T(-0.18676189, 0.02983525, -0.04895349))),
        T(T(T(-0.13829032, 0.28245789, -0.10234657),
            T(0.14773294, 0.28724921, -0.09669375),
            T(0.11997268, -0.19171268, 0.17503896)),
          T(T(-0.01335889, 0.27340567, 0.15419030),
            T(-0.27378151, 0.08404601, 0.20571443),
            T(0.03300169, -0.07807332, 0.27800083)),
          T(T(0.22714883, 0.10564631, 0.10429975),
            T(-0.15422256, 0.12877643, -0.07962382),
            T(0.05750173, 0.24986815, -0.24631210)),
          T(T(0.14758101, -0.14909469, -0.02427217),
            T(-0.22774965, 0.24656773, -0.09009914),
            T(-0.08819377, -0.14353877, 0.02373797)))))

    // P7 weights: nested as (2, 2, 3, 3).
    val p7W = Tensor(
      T(T(T(T(-0.00312749, 0.15891045, -0.06029734),
            T(0.32925665, -0.28568161, -0.22913636),
            T(0.25732505, -0.02756864, 0.22088635)),
          T(T(-0.19972500, -0.35011724, 0.36097509),
            T(0.13380224, 0.31481904, -0.34110975),
            T(0.00228858, -0.30160201, 0.39911568))),
        T(T(T(0.10945880, 0.04096296, 0.34124666),
            T(0.21367294, -0.40180174, 0.02459040),
            T(0.01582986, -0.35805190, 0.19427061)),
          T(T(-0.05247149, 0.03913751, -0.30283454),
            T(-0.06808761, 0.30844611, -0.25382966),
            T(0.39491993, -0.16227409, -0.33975506)))))

    // Reference outputs: three input levels followed by the P6 and P7 levels.
    val expected1 = Tensor(
      T(T(T(T(-0.82784057, -0.23263724, 0.12529419, 0.92917746, 1.29870367,
              0.59191561, 0.72206885, 0.21600738),
            T(0.01256093, 0.10397450, -0.63768029, 0.11551863, 0.49406019,
              0.04690269, 0.59401810, 0.51324034),
            T(-0.57496190, 0.80619323, 0.21891174, 0.46305850, -0.29128370,
              -0.10264862, -0.19434255, -0.98540932),
            T(-0.68715960, -0.63924152, -0.28786534, 0.21412316, 0.14116248,
              0.42578453, 0.50156069, 0.45927033),
            T(-0.50619870, 0.29627720, -0.08331068, 0.55051923, 0.87432826,
              0.22587368, 0.05705506, -0.60149169),
            T(-1.08125985, -0.19702393, 0.76295900, 0.24722506, 0.03166249,
              0.35292828, 0.89928788, 0.76004601),
            T(-1.24697328, -0.93874627, -0.32030576, 0.52993482, 0.88237661,
              -0.27623445, -0.30513218, -0.13993274),
            T(-0.04807660, 0.65625536, 0.34513134, 0.66153854, 1.02909780,
              0.85668772, 0.74089944, 0.41518468)),
          T(T(-0.43007800, -0.08928002, -0.18057445, -0.53732187, -1.18706334,
              -1.21151233, -1.24502730, -0.81248140),
            T(-0.41591629, -1.23449159, -0.37417489, -0.38646969, -0.35627022,
              -0.87439328, -0.63093376, -0.38696021),
            T(-0.19757175, -0.80025971, -0.61968642, -0.58690047, -0.29969466,
              -0.66096216, -0.69664645, -0.61171681),
            T(0.36569861, -0.43666506, -0.23078403, -1.02591038, -0.34318402,
              -1.12366092, -1.22326660, -0.95382887),
            T(0.22266409, -0.61877912, -1.04867685, -0.58774620, -0.58317888,
              -1.11619925, -1.20713544, -1.40455294),
            T(0.12496996, -0.14055842, -0.44808233, -0.85750657, -0.82033932,
              -0.74288636, -1.17979848, -1.27777565),
            T(-0.13870570, 0.01701410, -0.15212707, -1.16827607, -0.73849547,
              -0.94292432, -0.49970376, -0.77397305),
            T(0.20336530, -0.83974218, -0.26997119, -0.33915856, -0.64899278,
              -0.23277763, -0.45086405, -0.36021605)))))

    val expected2 = Tensor(
      T(T(T(T(-0.24274831, -0.00741440, 0.02368479, 0.02028912),
            T(-0.25623968, 0.24358158, 0.28456029, 0.18310754),
            T(-0.31125316, 0.21291766, 0.18210757, -0.08399948),
            T(0.36639184, 0.86005092, 0.24853867, -0.11300255)),
          T(T(0.01302569, -0.06152091, -0.11953454, 0.00511467),
            T(0.14370050, 0.07275833, 0.22634801, 0.10956798),
            T(0.11319179, 0.22101365, 0.10727698, -0.24114035),
            T(0.63779628, 0.32476583, 0.01623568, 0.07922356)))))

    val expected3 = Tensor(
      T(T(T(T(-0.68622679, -0.72637987),
            T(-0.19554541, -0.29644468)),
          T(T(-0.24871959, 0.35015780),
            T(0.00696990, 0.17378137)))))

    val expected4 = Tensor(
      T(T(T(T(-0.03753495)),
          T(T(0.18758345)))))

    val expected5 = Tensor(
      T(T(T(T(0.05905484)),
          T(T(0.05785938)))))

    // Parameter slots are filled in this order: inner weights from level 3 down to
    // level 1, then P6 and P7, then layer weights from level 3 down to level 1,
    // each weight followed by its bias.
    val slots = model.parameters()._1
    Seq(inner3W, zeroBias(), inner2W, zeroBias(), inner1W, zeroBias(),
      p6W, zeroBias(), p7W, zeroBias(),
      layer3W, zeroBias(), layer2W, zeroBias(), layer1W, zeroBias())
      .zipWithIndex
      .foreach { case (src, idx) => slots(idx).copy(src) }

    val actual = model.forward(T(level1, level2, level3)).toTable

    // Output tables are 1-indexed, hence `idx + 1`.
    Seq(expected1, expected2, expected3, expected4, expected5).zipWithIndex.foreach {
      case (want, idx) =>
        Equivalent.nearequals(actual.get[Tensor[Float]](idx + 1).get, want) should be(true)
    }
  }

  /*
   * Forward pass of FPN with two top blocks (P6/P7) on a three-level pyramid.
   *
   * The inputs are three feature maps with 1, 2 and 4 channels at 8x8, 4x4 and 2x2.
   * Every conv weight is set to a fixed value, and all biases are zero.
   * The five outputs are compared against hard-coded reference tensors:
   *   - three pyramid levels (8x8, 4x4, 2x2);
   *   - the two top-block outputs (1x1 each).
   *
   * Here inChannelsOfP6P7 == outChannelsOfP6P7. Per the test name, this is the configuration
   * in which the P6/P7 blocks are fed from P5.
   * NOTE(review): confirm against FPN's implementation.
   */
  "FPN updateOutput with P6P7 TopBlocks use P5" should "work correctly" in {
    val in_channels_list = Array(1, 2, 4)
    val out_channels = 2
    // inChannelsOfP6P7 == outChannelsOfP6P7: the "use P5" case this test is named after.
    val model = FPN[Float](in_channels_list, out_channels, topBlocks = 2,
      inChannelsOfP6P7 = 2, outChannelsOfP6P7 = 2)

    val feature1 = Tensor(
      T(T(0.10110152, 0.10345000, 0.04320979, 0.84362656,
          0.59594363, 0.97288179, 0.34699517, 0.54275155),
        T(0.93956870, 0.07543808, 0.50965708, 0.26184946,
          0.92378283, 0.83272308, 0.54440099, 0.56682664),
        T(0.53608388, 0.74091697, 0.53824615, 0.12760854,
          0.70029002, 0.85137993, 0.01918983, 0.10134047),
        T(0.61024511, 0.11725241, 0.46950370, 0.15163177,
          0.99792290, 0.50036842, 0.65618765, 0.76569498),
        T(0.31238246, 0.96460360, 0.23587847, 0.94086981,
          0.15270233, 0.44916826, 0.53412461, 0.19992995),
        T(0.14841199, 0.95466810, 0.89249784, 0.10235202,
          0.24293590, 0.83814293, 0.78163254, 0.94990700),
        T(0.50397956, 0.23095572, 0.12026519, 0.70295823,
          0.80230796, 0.31913465, 0.86270124, 0.67926580),
        T(0.93120003, 0.08011329, 0.30662805, 0.97467756,
          0.32988423, 0.90689850, 0.46856666, 0.66390038)))
      .reshape(Array(1, 1, 8, 8))

    val feature2 = Tensor(
      T(T(T(0.30143285, 0.63111430, 0.45092928, 0.22753167),
          T(0.80318344, 0.67537767, 0.14698678, 0.45962620),
          T(0.21663177, 0.89086282, 0.92865956, 0.89360029),
          T(0.49615270, 0.46269470, 0.73047608, 0.12438315)),
        T(T(0.75820625, 0.59779423, 0.61585987, 0.35782731),
          T(0.36951083, 0.35381025, 0.64314663, 0.75517660),
          T(0.30200917, 0.69998586, 0.29572868, 0.46342885),
          T(0.41677684, 0.26154006, 0.16909349, 0.94081402))))
      .reshape(Array(1, 2, 4, 4))

    val feature3 = Tensor(
      T(T(T(0.57270211, 0.25789189),
          T(0.79134840, 0.62564188)),
        T(T(0.27365083, 0.43420678),
          T(0.61281836, 0.23570287)),
        T(T(0.21393263, 0.50206852),
          T(0.50650394, 0.73282623)),
        T(T(0.20319027, 0.06753725),
          T(0.18215942, 0.36703324))))
      .reshape(Array(1, 4, 2, 2))

    val inner1_w = Tensor(
      T(T(T(T(-0.12393522))),
        T(T(T(-0.49485075)))))
      .reshape(Array(1, 2, 1, 1, 1))
    val inner1_b = Tensor(T(0, 0))

    val inner2_w = Tensor(
      T(T(T(T(-0.95695794)),
          T(T(0.55932796))),
        T(T(T(0.22264385)),
          T(T(0.64771581)))))
      .reshape(Array(1, 2, 2, 1, 1))
    val inner2_b = Tensor(T(0, 0))

    val inner3_w = Tensor(
      T(T(T(T(0.47477275)),
          T(T(0.04092562)),
          T(T(-0.01725465)),
          T(T(-0.34568024))),
        T(T(T(-0.79893148)),
          T(T(-0.66726011)),
          T(T(-0.14056665)),
          T(T(-0.75817424)))))
      .reshape(Array(1, 2, 4, 1, 1))
    val inner3_b = Tensor(T(0, 0))

    val layer1_w = Tensor(
      T(T(T(T(-0.32906294, -0.08600309, -0.38722333),
            T(-0.29580453, 0.40037566, -0.16175754),
            T(0.25444168, 0.11281389, -0.07697448)),
          T(T(-0.20765188, -0.30854949, 0.33915347),
            T(-0.05911121, -0.20772298, 0.36908209),
            T(0.39145410, 0.07839337, 0.09654927))),
        T(T(T(-0.26997358, -0.21366502, -0.14226845),
            T(-0.05312893, -0.10671085, 0.37542689),
            T(-0.28042397, 0.02129859, 0.33310878)),
          T(T(-0.39731082, -0.22968259, 0.31097382),
            T(0.24397695, -0.38017231, 0.40436870),
            T(0.25588512, -0.12146497, 0.10941350)))))
      .reshape(Array(1, 2, 2, 3, 3))
    val layer1_b = Tensor(T(0, 0))

    val layer2_w = Tensor(
      T(T(T(T(-0.13642263, 0.21656078, -0.01871455),
            T(-0.20130268, -0.25516552, 0.34926140),
            T(0.13896102, -0.37103790, 0.23734450)),
          T(T(0.08139789, 0.37057930, 0.38387370),
            T(0.34906447, 0.30327201, 0.23043340),
            T(0.04161811, -0.07575810, 0.25803828))),
        T(T(T(-0.34966460, -0.22834912, 0.01767731),
            T(-0.16592246, -0.36947623, -0.01893327),
            T(0.18922144, -0.23139042, -0.28582191)),
          T(T(-0.02167633, -0.23346797, 0.00187096),
            T(0.14594424, 0.39863366, 0.11338776),
            T(-0.33135366, -0.30160487, -0.29802644)))))
      .reshape(Array(1, 2, 2, 3, 3))
    val layer2_b = Tensor(T(0, 0))

    val layer3_w = Tensor(
      T(T(T(T(-0.00817868, 0.29565400, -0.03227356),
            T(-0.11617559, 0.20846748, -0.03688866),
            T(-0.05434576, -0.04842332, 0.02425647)),
          T(T(0.18390912, -0.00540081, 0.29155219),
            T(0.23329431, -0.11891335, 0.24823219),
            T(0.23775083, -0.04294857, -0.34929958))),
        T(T(T(0.12607539, 0.23896956, 0.01926240),
            T(-0.09790298, -0.30780315, 0.14969867),
            T(0.00337335, -0.31408104, -0.37880355)),
          T(T(-0.25468409, 0.14823782, 0.40019959),
            T(-0.03427723, -0.04853129, 0.02510184),
            T(0.25904632, 0.34354115, 0.10385382)))))
      .reshape(Array(1, 2, 2, 3, 3))
    val layer3_b = Tensor(T(0, 0))

    // Reshaped to the same (1, out, in, kH, kW) layout as every other conv weight above.
    val p6_w = Tensor(
      T(T(T(T(-0.29818240, -0.22222301, -0.27881575),
            T(0.19775295, -0.04910746, -0.11908785),
            T(0.27843529, 0.29707819, 0.30488032)),
          T(T(-0.12720180, 0.08535665, -0.33813587),
            T(-0.02545372, -0.38678339, 0.11843002),
            T(-0.06442717, -0.00726947, 0.14210951))),
        T(T(T(0.09151843, -0.08247298, 0.07027003),
            T(0.23849773, 0.26687491, -0.03779584),
            T(0.35821044, 0.17980134, -0.07383940)),
          T(T(0.31192303, 0.35286152, -0.00549397),
            T(-0.33862600, -0.27117044, 0.37536448),
            T(-0.23844577, -0.23443575, -0.35118651)))))
      .reshape(Array(1, 2, 2, 3, 3))
    val p6_b = Tensor(T(0, 0))

    val p7_w = Tensor(
      T(T(T(T(0.27193576, 0.09538496, 0.09932923),
            T(0.23505199, -0.08323237, -0.35580242),
            T(0.06272587, -0.19060957, 0.32343888)),
          T(T(0.24513763, -0.01483554, 0.25192779),
            T(0.26561451, 0.05530944, 0.30232042),
            T(0.30819184, 0.09326428, 0.12093598))),
        T(T(T(0.32881397, -0.17656034, -0.26700664),
            T(-0.16808785, 0.38506639, -0.15014803),
            T(0.21106857, 0.21199214, -0.31056783)),
          T(T(-0.16920617, 0.12196451, 0.08281082),
            T(0.22818404, -0.17261851, -0.29054090),
            T(-0.21099238, -0.04546800, -0.15372574)))))
      .reshape(Array(1, 2, 2, 3, 3))
    val p7_b = Tensor(T(0, 0))

    val result1 = Tensor(
      T(T(T(T(-0.05875708, -0.24495505, -0.40095121, -0.36195034, -0.34048805,
              -0.42995015, -0.37165138, -0.12609129),
            T(-0.23046799, -0.59160781, -0.81221211, -0.36748683, -0.22889382,
              -0.15329736, -0.12485854, 0.45016751),
            T(-0.38892403, -0.57008761, -0.31188434, -0.44806325, -0.14688066,
              0.00538448, -0.08428087, 0.16207692),
            T(-0.20208777, -0.19664873, -0.20870245, -0.44272959, -0.46557492,
              -0.41775605, -0.76812929, -0.44619679),
            T(-0.20316558, -0.05055160, -0.55165714, -0.40639529, -0.49637964,
              -0.66946077, -0.75888383, -0.29708627),
            T(-0.65783459, -0.16802897, -0.41265154, -0.02700083, -0.49787420,
              -0.34201804, -0.01878840, 0.55896097),
            T(-0.59010309, -0.48106664, -0.34888858, -0.17606093, -0.57338041,
              -0.27389777, 0.12463056, 0.86562246),
            T(-0.09638648, 0.05499656, -0.18625061, 0.50743264, 0.32407704,
              0.19390954, 0.34793308, 0.43689337)),
          T(T(0.25326055, -0.47114417, -0.50304997, -0.12190270, -0.38302276,
              -0.16330689, -0.28812358, 0.01487039),
            T(-0.04750883, -0.58049947, -0.28602204, -0.01689222, 0.16504316,
              -0.08511922, -0.14781611, 0.38237956),
            T(-0.29861927, -0.31464750, -0.18262222, 0.11181816, 0.16474791,
              0.07716653, -0.16424689, 0.33493862),
            T(0.09711685, -0.16085583, -0.12101643, -0.05731618, -0.25519797,
              -0.15982063, -0.20263793, 0.06042653),
            T(0.07343097, -0.05079371, -0.65273732, -0.35203332, -0.65474921,
              -0.31770957, -0.47713327, 0.37339261),
            T(-0.37637535, -0.02805071, -0.12414282, -0.43823543, -0.33148623,
              0.35421544, 0.32711336, 0.69742543),
            T(-0.16029462, -0.29591912, -0.06338350, -0.28330535, -0.57328767,
              0.75277287, 0.65203953, 0.65192145),
            T(0.10750079, -0.11812737, -0.21442677, 0.17212176, 0.05892290,
              0.85415667, 0.09521016, 0.31057179)))))

    val result2 = Tensor(
      T(T(T(T(-0.26799202, -0.16586389, -0.32381740, -0.16350438),
            T(-0.57322353, -0.51086164, -0.43073779, -0.17672254),
            T(-1.16777086, -1.43536067, -0.29853198, -0.48934227),
            T(-1.03757823, -1.56285250, -0.94800329, -0.71664643)),
          T(T(0.13909623, 0.00816379, -0.06897804, -0.08329182),
            T(0.35677588, 0.80284458, 0.61186469, 0.29867059),
            T(0.12059141, 0.79457754, 0.50915569, -0.02383049),
            T(-0.32545245, -0.54294991, -0.02146160, 0.02067354)))))

    val result3 = Tensor(
      T(T(T(T(0.38887578, -0.39928365),
            T(-0.16387880, -0.28960186)),
          T(T(-0.72778845, -0.72761118),
            T(-0.35890821, 0.18019901)))))

    val result4 = Tensor(
      T(T(T(T(0.11501697)),
          T(T(0.05588362)))))

    val result5 = Tensor(
      T(T(T(T(-0.00648224)),
          T(T(0.03464263)))))

    val input = T(feature1, feature2, feature3)
    // Expected outputs in table order: output table key k corresponds to expected(k - 1).
    val expected = Seq(result1, result2, result3, result4, result5)

    // Source tensors in the index order of model.parameters()._1 for this configuration:
    // index 0 -> inner3_w, 1 -> inner3_b, ..., 15 -> layer1_b.
    val paramSources = Seq(
      inner3_w, inner3_b, inner2_w, inner2_b, layer3_w, layer3_b,
      inner1_w, inner1_b, p6_w, p6_b, p7_w, p7_b,
      layer2_w, layer2_b, layer1_w, layer1_b)
    // Fetch the parameter list once instead of once per copy.
    val weights = model.parameters()._1
    paramSources.zipWithIndex.foreach { case (src, idx) => weights(idx).copy(src) }

    val output = model.forward(input).toTable

    expected.zipWithIndex.foreach { case (want, idx) =>
      val level = idx + 1 // Table keys are 1-based
      // The clue identifies which output level mismatched when an assertion fails.
      withClue(s"output $level: ") {
        Equivalent.nearequals(output.get[Tensor[Float]](level).get, want) should be(true)
      }
    }
  }
}

class FPNSerialTest extends ModuleSerializationTest {
  /**
   * Serialization round-trip check for an FPN built without top blocks.
   *
   * Three feature maps are filled with `Random.nextFloat()` values in level order.
   * Their shapes are (1, 1, 8, 8), (1, 2, 4, 4) and (1, 4, 2, 2).
   * They are stored in a table under the Float keys 1.0f, 2.0f and 3.0f.
   * The named module is then handed to `runSerializationTest` with that table.
   */
  override def test(): Unit = {
    val levelMaps = Seq(
      Tensor[Float](1, 1, 8, 8),
      Tensor[Float](1, 2, 4, 4),
      Tensor[Float](1, 4, 2, 2)
    ).map(t => t.apply1(_ => Random.nextFloat()))

    val input = T()
    for ((levelMap, pos) <- levelMaps.zipWithIndex) {
      // Keys are Floats (1.0f, 2.0f, 3.0f).
      input((pos + 1).toFloat) = levelMap
    }

    val pyramid = new FPN[Float](inChannels = Array(1, 2, 4), outChannels = 2, topBlocks = 0)
      .setName("FPN")
    runSerializationTest(pyramid, input)
  }
}
